﻿using PetaPoco;
using PetaPoco.Core;
using PetaPoco.Providers;
using PetaPoco.Utilities;
using PmSoft.Core.Domain.Entities;
using PmSoft.Data.Abstractions;

namespace PmSoft.Data.PetaPoco;

public static class PetaPocoDatabaseExtensions
{
	#region 获取前 topNumber 条记录

	/// <summary>
	/// 获取前 topNumber 条记录的主键集合
	/// </summary>
	public static IEnumerable<TKey> FetchTopPrimaryKeys<TEntity, TKey>(this Database database, int topNumber, Sql sql)
		where TEntity : IEntity<TKey>, new() where TKey : notnull
	{
		var sqlArguments = sql.Arguments;
		var topSql = database.BuildTopSql<TEntity>(topNumber, sql.SQL);
		return database.FetchFirstColumn<TKey>(topSql, sqlArguments);
	}

	/// <summary>
	/// 异步获取前 topNumber 条记录的主键集合
	/// </summary>
	public static async Task<IEnumerable<TKey>> FetchTopPrimaryKeysAsync<TEntity, TKey>(this Database database, int topNumber, Sql sql)
		where TEntity : IEntity<TKey>, new() where TKey : notnull
	{
		var sqlArguments = sql.Arguments;
		var topSql = database.BuildTopSql<TEntity>(topNumber, sql.SQL);
		return await database.FetchFirstColumnAsync<TKey>(topSql, sqlArguments).ConfigureAwait(false);
	}

	private static string BuildTopSql<T>(this Database database, int topNumber, string sql)
	{
		var pocoData = PocoData.ForType(typeof(T), database.DefaultMapper);
		var primaryKey = $"{pocoData.TableInfo.TableName}.{pocoData.TableInfo.PrimaryKey}";
		if (database.EnableAutoSelect)
		{
			sql = database.Provider.AddSelectClause<T>(sql, database.DefaultMapper, primaryKey);
		}
		return database.BuildTopSql(topNumber, sql);
	}

	private static string BuildTopSql(this Database database, int topNumber, string sql)
	{
		var match = ((PagingHelper)database.Provider.PagingUtility).RegexColumns.Match(sql);
		if (!match.Success)
			return string.Empty;

		var columnsGroup = match.Groups[1];
		return database.Provider is not SqlServerDatabaseProvider
			? $"{sql.Substring(0, columnsGroup.Index)} {columnsGroup.Value} {sql.Substring(columnsGroup.Index + columnsGroup.Length)} limit {topNumber}"
			: $"{sql.Substring(0, columnsGroup.Index)} top {topNumber} {columnsGroup.Value} {sql.Substring(columnsGroup.Index + columnsGroup.Length)}";
	}

	#endregion

	#region 根据主键值查询集合

	/// <summary>
	/// 根据主键值查询实体集合（自定义单次最大主键数量）
	/// </summary>
	/// <param name="batchSize">单次查询允许的最大主键数量（建议根据数据库特性设置，如SQL Server设为2000）</param>
	public static IEnumerable<TEntity> FetchByPrimaryKeys<TEntity, TKey>(
		this Database database,
		IEnumerable<TKey> primaryKeys,
		int batchSize = 1000)
			where TKey : notnull
	{
		// 参数校验
		if (primaryKeys == null)
			throw new ArgumentNullException(nameof(primaryKeys));
		if (batchSize <= 0)
			throw new ArgumentOutOfRangeException(nameof(batchSize), "批次大小必须大于0");

		// 处理空集合
		var distinctKeys = primaryKeys.Distinct().ToList();
		if (distinctKeys.Count == 0)
			return Enumerable.Empty<TEntity>();

		// 分批次查询
		var batchKeysList = distinctKeys
			.Select((key, index) => new { key, index })
			.GroupBy(x => x.index / batchSize)
			.Select(g => g.Select(x => x.key).ToList())
			.ToList();

		// 执行所有批次查询并合并结果
		var results = batchKeysList
			.Select(batchKeys => database.FetchByPrimaryKeys<TEntity, TKey>(batchKeys))
			.SelectMany(x => x);

		return results;
	}

	/// <summary>
	/// 根据主键值查询实体集合
	/// </summary>
	private static IEnumerable<TEntity> FetchByPrimaryKeys<TEntity, TKey>(this Database database, IEnumerable<TKey> primaryKeys)
		where TKey : notnull
	{
		if (primaryKeys == null)
			throw new ArgumentNullException(nameof(primaryKeys));
		if (!primaryKeys.Any())
			return new List<TEntity>();

		var primaryKey = database.Provider.EscapeSqlIdentifier(PocoData.ForType(typeof(TEntity), database.DefaultMapper).TableInfo.PrimaryKey);
		var sql = Sql.Builder.Where($"{primaryKey} IN(@PrimaryKeys)", new { PrimaryKeys = primaryKeys });
		return database.Fetch<TEntity>(sql);
	}

	/// <summary>
	/// 异步根据主键值查询实体集合（自定义单次最大主键数量）
	/// </summary>
	/// <param name="batchSize">单次查询允许的最大主键数量（建议根据数据库特性设置，如SQL Server设为2000）</param>
	public static async Task<IEnumerable<TEntity>> FetchByPrimaryKeysAsync<TEntity, TKey>(
		this Database database,
		IEnumerable<TKey> primaryKeys,
		int batchSize = 1000)
			where TKey : notnull
	{
		// 参数校验
		if (primaryKeys == null)
			throw new ArgumentNullException(nameof(primaryKeys));
		if (batchSize <= 0)
			throw new ArgumentOutOfRangeException(nameof(batchSize), "批次大小必须大于0");

		// 处理空集合
		var distinctKeys = primaryKeys.Distinct().ToList();
		if (distinctKeys.Count == 0)
			return Enumerable.Empty<TEntity>();

		// 分批次查询
		var batchKeysList = distinctKeys
			.Select((key, index) => new { key, index })
			.GroupBy(x => x.index / batchSize)
			.Select(g => g.Select(x => x.key).ToList())
			.ToList();

		// 打开共享连接（所有批次共用同一连接）
		await database.OpenSharedConnectionAsync().ConfigureAwait(false);
		try
		{
			var results = new List<TEntity>();
			foreach (var batchKeys in batchKeysList)
			{
				var batchResults = await database.FetchByPrimaryKeysAsync<TEntity, TKey>(primaryKeys).ConfigureAwait(false);
				results.AddRange(batchResults);
			}
			return results;
		}
		finally
		{
			database.CloseSharedConnection();
		}
	}

	/// <summary>
	/// 异步根据主键值查询实体集合
	/// </summary>
	private static async Task<IEnumerable<TEntity>> FetchByPrimaryKeysAsync<TEntity, TKey>(this Database database, IEnumerable<TKey> primaryKeys)
		where TKey : notnull
	{
		if (primaryKeys == null)
			throw new ArgumentNullException(nameof(primaryKeys));
		if (!primaryKeys.Any())
			return new List<TEntity>();

		var primaryKey = database.Provider.EscapeSqlIdentifier(PocoData.ForType(typeof(TEntity), database.DefaultMapper).TableInfo.PrimaryKey);
		var sql = Sql.Builder.Where($"{primaryKey} IN(@PrimaryKeys)", new { PrimaryKeys = primaryKeys });
		return await database.FetchAsync<TEntity>(sql).ConfigureAwait(false);
	}

	#endregion

	#region 获取第一列组成的集合

	/// <summary>
	/// 获取查询结果的第一列数据
	/// </summary>
	public static IEnumerable<T> FetchFirstColumn<T>(this Database database, Sql sql)
	{
		return database.FetchFirstColumn<T>(sql.SQL, sql.Arguments);
	}

	/// <summary>
	/// 获取查询结果的第一列数据
	/// </summary>
	public static IEnumerable<T> FetchFirstColumn<T>(this Database database, string sql, params object[] args)
	{
		var resultList = new List<T>();
		try
		{
			database.OpenSharedConnection();
			using (var command = database.CreateCommand(database.Connection, sql, args))
			using (var reader = command.ExecuteReader())
			{
				database.OnExecutedCommand(command);
				while (reader.Read())
				{
					if (reader[0] is T value)
						resultList.Add(value);
				}
			}
		}
		finally
		{
			database.CloseSharedConnection();
		}
		return resultList;
	}

	/// <summary>
	/// 异步获取查询结果的第一列数据
	/// </summary>
	public static async Task<IEnumerable<T>> FetchFirstColumnAsync<T>(this Database database, Sql sql)
	{
		return await database.FetchFirstColumnAsync<T>(sql.SQL, sql.Arguments).ConfigureAwait(false);
	}

	/// <summary>
	/// 异步获取查询结果的第一列数据
	/// </summary>
	public static async Task<IEnumerable<T>> FetchFirstColumnAsync<T>(this Database database, string sql, params object[] args)
	{
		var resultList = new List<T>();
		try
		{
			await database.OpenSharedConnectionAsync().ConfigureAwait(false);
			using (var command = database.CreateCommand(database.Connection, sql, args))
			using (var reader = command.ExecuteReader())
			{
				database.OnExecutedCommand(command);
				while (reader.Read())
				{
					if (reader[0] is T value)
						resultList.Add(value);
				}
			}
		}
		finally
		{
			database.CloseSharedConnection();
		}
		return resultList;
	}

	#endregion

	#region 分页的主键集合

	/// <summary>
	/// 分页获取主键集合
	/// </summary>
	public static PagedEntityIdsCollection<TKey> FetchPagingPrimaryKeys<TEntity, TKey>(this Database database, long maxRecords, int pageSize, int pageIndex, Sql sql)
		where TKey : notnull
	{
		var sqlQuery = sql.SQL;
		var sqlArguments = sql.Arguments;
		database.BuildPagingPrimaryKeyQueries<TEntity>(maxRecords, (pageIndex - 1) * pageSize, pageSize, sqlQuery, ref sqlArguments, out var countSql, out var pagingSql);
		return new PagedEntityIdsCollection<TKey>(database.FetchFirstColumn<TKey>(pagingSql, sqlArguments), database.ExecuteScalar<int>(countSql, sqlArguments));
	}

	/// <summary>
	/// 异步分页获取主键集合
	/// </summary>
	public static async Task<PagedEntityIdsCollection<TKey>> FetchPagingPrimaryKeysAsync<TEntity, TKey>(this Database database, long maxRecords, int pageSize, int pageIndex, Sql sql)
		where TKey : notnull
	{
		var sqlQuery = sql.SQL;
		var sqlArguments = sql.Arguments;
		database.BuildPagingPrimaryKeyQueries<TEntity>(maxRecords, (pageIndex - 1) * pageSize, pageSize, sqlQuery, ref sqlArguments, out var countSql, out var pagingSql);
		return new PagedEntityIdsCollection<TKey>(await database.FetchFirstColumnAsync<TKey>(pagingSql, sqlArguments).ConfigureAwait(false), await database.ExecuteScalarAsync<int>(countSql, sqlArguments).ConfigureAwait(false));
	}

	private static void BuildPagingPrimaryKeyQueries<TEntity>(this Database database, long maxRecords, long skip, long take, string sql, ref object[] args, out string countSql, out string pagingSql)
	{
		var pocoData = PocoData.ForType(typeof(TEntity), database.DefaultMapper);
		var primaryKey = sql.Contains(pocoData.TableInfo.TableName) ? $"{pocoData.TableInfo.TableName}.{pocoData.TableInfo.PrimaryKey}" : pocoData.TableInfo.PrimaryKey;
		if (database.EnableAutoSelect)
		{
			sql = database.Provider.AddSelectClause<TEntity>(sql, database.DefaultMapper, primaryKey);
		}
		database.BuildPagingPrimaryKeyQueries(maxRecords, skip, take, primaryKey, sql, ref args, out countSql, out pagingSql);
	}

	private static void BuildPagingPrimaryKeyQueries(this Database database, long maxRecords, long skip, long take, string primaryKey, string sql, ref object[] args, out string countSql, out string pagingSql)
	{
		if (!database.SplitSqlForPagingOptimized(maxRecords, sql, primaryKey, out var sqlParts))
			throw new Exception("Unable to parse SQL statement for paged query");

		countSql = sqlParts.SqlCount;
		pagingSql = database.Provider.BuildPageQuery(skip, take, sqlParts, ref args);
	}

	private static bool SplitSqlForPagingOptimized(this Database database, long maxRecords, string sql, string primaryKey, out SQLParts sqlParts)
	{
		sqlParts.Sql = sql;
		sqlParts.SqlSelectRemoved = string.Empty;
		sqlParts.SqlCount = sql;
		sqlParts.SqlOrderBy = string.Empty;

		var pagingUtility = (PagingHelper)database.Provider.PagingUtility;
		var columnsMatch = pagingUtility.RegexColumns.Match(sql);
		if (!columnsMatch.Success)
			return false;

		var orderByMatch = pagingUtility.RegexOrderBy.Match(sql);
		if (orderByMatch.Success)
		{
			sqlParts.SqlOrderBy = orderByMatch.Value;
			sqlParts.SqlCount = sql.Replace(orderByMatch.Value, string.Empty);
		}

		var columnsGroup = columnsMatch.Groups[1];
		sqlParts.SqlSelectRemoved = sql.Substring(columnsGroup.Index);

		if (database.Provider is SqlServerDatabaseProvider)
		{
			sqlParts.SqlCount = $"select count(*) from ({sql.Substring(0, columnsGroup.Index)} top {maxRecords} {primaryKey} {sql.Substring(columnsGroup.Index + columnsGroup.Length)}) as TempCountTable";
		}
		else
		{
			if (pagingUtility.RegexDistinct.IsMatch(sqlParts.SqlSelectRemoved) || pagingUtility.SimpleRegexGroupBy.IsMatch(sqlParts.SqlSelectRemoved))
				sqlParts.SqlCount = $"{sql.Substring(0, columnsGroup.Index)} COUNT(*) FROM  ({sqlParts.SqlCount} limit {maxRecords}) countAlias";
			else
				sqlParts.SqlCount = $"select count(*) from ({sqlParts.SqlCount} limit {maxRecords}) as TempCountTable";
		}
		return true;
	}

	#endregion
}